{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-1vOMEXIhMQt"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "pRfq9ZU5hQhg"
      },
      "outputs": [],
      "source": [
        "#@title Copyright 2020 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mTL0TERThT6z"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/hub/tutorials/bert_experts\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/bert_experts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/hub/blob/master/examples/colab/bert_experts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/hub/examples/colab/bert_experts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://tfhub.dev/s?q=experts%2Fbert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub models\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FkthMlVk8bHp"
      },
      "source": [
        "# BERT Experts from TF-Hub\n",
        "\n",
        "This colab demonstrates how to:\n",
        "* Load BERT models from [TensorFlow Hub](https://tfhub.dev) that have been trained on different datasets and tasks, including MNLI, SQuAD, and PubMed\n",
        "* Use a matching preprocessing model to tokenize raw text and convert it to token ids\n",
        "* Generate the pooled and sequence outputs from the token ids using the loaded model\n",
        "* Compare the semantic similarity of the pooled outputs for different sentences\n",
        "\n",
        "#### Note: This colab should be run with a GPU runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jspO02jDPfPG"
      },
      "source": [
        "## Set up and imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r-ed8zj-dbwm"
      },
      "outputs": [],
      "source": [
        "!pip3 install --quiet tensorflow\n",
        "!pip3 install --quiet tensorflow_text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "czDmtrGKYw_5"
      },
      "outputs": [],
      "source": [
        "import seaborn as sns\n",
        "from sklearn.metrics import pairwise\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "import tensorflow_text as text  # Imports TF ops for preprocessing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "GSuDcPSaY5aB"
      },
      "outputs": [],
      "source": [
        "#@title Configure the model { run: \"auto\" }\n",
        "BERT_MODEL = \"https://tfhub.dev/google/experts/bert/wiki_books/2\" # @param {type: \"string\"} [\"https://tfhub.dev/google/experts/bert/wiki_books/2\", \"https://tfhub.dev/google/experts/bert/wiki_books/mnli/2\", \"https://tfhub.dev/google/experts/bert/wiki_books/qnli/2\", \"https://tfhub.dev/google/experts/bert/wiki_books/qqp/2\", \"https://tfhub.dev/google/experts/bert/wiki_books/squad2/2\", \"https://tfhub.dev/google/experts/bert/wiki_books/sst2/2\",  \"https://tfhub.dev/google/experts/bert/pubmed/2\", \"https://tfhub.dev/google/experts/bert/pubmed/squad2/2\"]\n",
        "# The preprocessing model must match the BERT model; all of the models above use the same one.\n",
        "PREPROCESS_MODEL = \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pvaZiGVgwtqw"
      },
      "source": [
        "## Sentences\n",
        "\n",
        "Let's take some sentences from Wikipedia to run through the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tytu-rSpeDNG"
      },
      "outputs": [],
      "source": [
        "sentences = [\n",
        "  \"Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.\",\n",
        "  \"The album went straight to number one on the Norwegian album chart, and sold to double platinum.\",\n",
        "  \"Among the singles released from the album were the songs \\\"Be My Lover\\\" and \\\"Hard To Stay Awake\\\".\",\n",
        "  \"Riccardo Zegna is an Italian jazz musician.\",\n",
        "  \"Rajko Maksimović is a composer, writer, and music pedagogue.\",\n",
        "  \"One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.\",\n",
        "  \"Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum\",\n",
        "  \"A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.\",\n",
        "  \"A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth.\",\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zI39475kxCKh"
      },
      "source": [
        "## Run the model\n",
        "\n",
        "We'll load the BERT model from TF-Hub, tokenize our sentences using the matching preprocessing model from TF-Hub, then feed the tokenized sentences to the model. To keep this colab fast and simple, we recommend running on a GPU.\n",
        "\n",
        "Go to **Runtime** → **Change runtime type** to make sure that **GPU** is selected."
      ]
    },
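    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optionally, you can confirm that TensorFlow sees a GPU before loading the models. The cell below is a small sketch that is not part of the original tutorial; it uses `tf.config.list_physical_devices` to list the visible GPU devices. If the list is empty, the models should still run on the CPU, only more slowly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Sketch: list the GPU devices that TensorFlow can see in this runtime.\n",
        "gpus = tf.config.list_physical_devices('GPU')\n",
        "if gpus:\n",
        "  print('GPUs available:', [gpu.name for gpu in gpus])\n",
        "else:\n",
        "  print('No GPU found; the models will run on the CPU, more slowly.')"
      ]
    },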
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x4t6r22ErQg0"
      },
      "outputs": [],
      "source": [
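        "# Load the preprocessing model and the BERT Expert from TF Hub.\n",
        "preprocess = hub.load(PREPROCESS_MODEL)\n",
        "bert = hub.load(BERT_MODEL)\n",
        "# Tokenize the raw sentences into a dict of fixed-length input tensors.\n",
        "inputs = preprocess(sentences)\n",
        "# Run BERT; the result is a dict that includes 'pooled_output' and 'sequence_output'.\n",
        "outputs = bert(inputs)"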
      ]
    },
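    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`preprocess` turns each raw sentence into a dictionary of integer tensors: `input_word_ids` (token ids starting with `[CLS]` and ending with `[SEP]`, padded to a fixed length, 128 by default for this preprocessing model), `input_mask` (1 for real tokens, 0 for padding) and `input_type_ids` (the segment id, all 0 for single sentences).\n",
        "\n",
        "The plain-Python sketch below imitates that packing step for one sentence. It is a simplification under stated assumptions, not the TF Text implementation: the WordPiece token ids are taken as given, the special ids 101 (`[CLS]`), 102 (`[SEP]`) and 0 (`[PAD]`) are assumed to be those of the standard uncased BERT vocabulary, and inputs that are too long are simply truncated. `pack_single_sentence` is a hypothetical helper name."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical helper: a simplified imitation of BERT input packing.\n",
        "# Assumes special-token ids from the standard uncased BERT vocabulary.\n",
        "CLS_ID, SEP_ID, PAD_ID = 101, 102, 0\n",
        "\n",
        "def pack_single_sentence(token_ids, seq_length=128):\n",
        "  # Leave room for [CLS] and [SEP]; truncate anything longer.\n",
        "  token_ids = list(token_ids)[:seq_length - 2]\n",
        "  word_ids = [CLS_ID] + token_ids + [SEP_ID]\n",
        "  num_real = len(word_ids)\n",
        "  num_pad = seq_length - num_real\n",
        "  return {\n",
        "      # Token ids, padded with [PAD] up to seq_length.\n",
        "      'input_word_ids': np.array(word_ids + [PAD_ID] * num_pad, dtype=np.int32),\n",
        "      # 1 for real tokens (including [CLS] and [SEP]), 0 for padding.\n",
        "      'input_mask': np.array([1] * num_real + [0] * num_pad, dtype=np.int32),\n",
        "      # A single segment, so every position has segment id 0.\n",
        "      'input_type_ids': np.zeros(seq_length, dtype=np.int32),\n",
        "  }\n",
        "\n",
        "# Toy token ids, packed to a short length so the result is easy to read.\n",
        "packed = pack_single_sentence([7592, 2088], seq_length=8)\n",
        "for key, value in packed.items():\n",
        "  print(key, value)"
      ]
    },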
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gItjCg4315Cv"
      },
      "outputs": [],
      "source": [
        "print(\"Sentences:\")\n",
        "print(sentences)\n",
        "\n",
        "print(\"\\nBERT inputs:\")\n",
        "print(inputs)\n",
        "\n",
        "print(\"\\nPooled embeddings:\")\n",
        "print(outputs[\"pooled_output\"])\n",
        "\n",
        "print(\"\\nPer token embeddings:\")\n",
        "print(outputs[\"sequence_output\"])"
      ]
    },
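    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model returns `pooled_output` with one vector per sentence (shape `[batch_size, hidden_size]`) and `sequence_output` with one vector per token position (shape `[batch_size, seq_length, hidden_size]`). This colab compares sentences using `pooled_output`.\n",
        "\n",
        "A common alternative sentence embedding, which this colab does not use, is the mean of the per-token vectors over the real tokens only, using `input_mask` to ignore padding. The NumPy sketch below shows this masked mean pooling on small made-up arrays; `masked_mean_pool` is a hypothetical helper. With the real model you could pass `outputs['sequence_output'].numpy()` and `inputs['input_mask'].numpy()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def masked_mean_pool(sequence_output, input_mask):\n",
        "  # sequence_output: [batch, seq_length, hidden]; input_mask: [batch, seq_length].\n",
        "  mask = input_mask.astype(sequence_output.dtype)[:, :, np.newaxis]\n",
        "  # Sum only the vectors of real (unmasked) tokens.\n",
        "  summed = (sequence_output * mask).sum(axis=1)\n",
        "  # Divide by the number of real tokens (at least 1, to avoid dividing by zero).\n",
        "  counts = np.maximum(mask.sum(axis=1), 1.0)\n",
        "  return summed / counts\n",
        "\n",
        "# Made-up example: 2 sentences, 3 token positions, hidden size 2.\n",
        "# The last position of the first sentence is padding and is ignored.\n",
        "toy_sequence_output = np.array([[[1., 1.], [3., 3.], [100., 100.]],\n",
        "                                [[2., 0.], [4., 0.], [6., 0.]]])\n",
        "toy_input_mask = np.array([[1, 1, 0],\n",
        "                           [1, 1, 1]])\n",
        "print(masked_mean_pool(toy_sequence_output, toy_input_mask))"
      ]
    },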
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ptiW2mgw6x-l"
      },
      "source": [
        "## Semantic similarity\n",
        "\n",
        "Now let's take a look at the `pooled_output` embeddings of our sentences and compare how similar they are across sentences."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "GXrSO2Vc1Qtr"
      },
      "outputs": [],
      "source": [
        "#@title Helper functions\n",
        "\n",
        "def plot_similarity(features, labels):\n",
        "  \"\"\"Plot a similarity matrix of the embeddings.\"\"\"\n",
        "  cos_sim = pairwise.cosine_similarity(features)\n",
        "  sns.set(font_scale=1.2)\n",
        "  cbar_kws=dict(use_gridspec=False, location=\"left\")\n",
        "  g = sns.heatmap(\n",
        "      cos_sim, xticklabels=labels, yticklabels=labels,\n",
        "      vmin=0, vmax=1, cmap=\"Blues\", cbar_kws=cbar_kws)\n",
        "  g.tick_params(labelright=True, labelleft=False)\n",
        "  g.set_yticklabels(labels, rotation=0)\n",
        "  g.set_title(\"Semantic Textual Similarity\")"
      ]
    },
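    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`plot_similarity` relies on scikit-learn's `pairwise.cosine_similarity`, which divides the dot product of two vectors by the product of their L2 norms. The NumPy sketch below, using a hypothetical `cosine_similarity_matrix` helper, computes the same matrix by scaling every row to unit length and taking all pairwise dot products; for non-zero rows it should match scikit-learn up to floating-point error.\n",
        "\n",
        "Cosine similarity ranges from -1 to 1, while the heatmap's color scale is fixed to `vmin=0, vmax=1`, so any negative similarities are drawn with the same color as 0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def cosine_similarity_matrix(features):\n",
        "  # features: [num_sentences, hidden]; returns [num_sentences, num_sentences].\n",
        "  features = np.asarray(features, dtype=np.float64)\n",
        "  norms = np.linalg.norm(features, axis=1, keepdims=True)\n",
        "  # Scale each row to unit length, guarding against all-zero rows.\n",
        "  unit = features / np.maximum(norms, 1e-12)\n",
        "  # Dot products of unit vectors are cosine similarities.\n",
        "  return unit @ unit.T\n",
        "\n",
        "toy_features = np.array([[1., 0.], [1., 1.], [-1., 0.]])\n",
        "print(np.round(cosine_similarity_matrix(toy_features), 3))"
      ]
    },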
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "td6jcT0pJMZ5"
      },
      "outputs": [],
      "source": [
        "plot_similarity(outputs[\"pooled_output\"], sentences)"
      ]
    },
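    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To read the heatmap numerically, you could also list the most similar sentence pairs. The sketch below uses a hypothetical `top_similar_pairs` helper on a small made-up similarity matrix; with the real embeddings you could pass `pairwise.cosine_similarity(outputs['pooled_output'])` and `sentences` instead. Only the upper triangle of the matrix is considered, so each pair is reported once and self-similarity is skipped."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def top_similar_pairs(sim, labels, k=3):\n",
        "  # Indices of the upper triangle, excluding the diagonal (self-similarity).\n",
        "  rows, cols = np.triu_indices(len(labels), k=1)\n",
        "  scores = sim[rows, cols]\n",
        "  # Sort pairs by similarity, highest first, and keep the top k.\n",
        "  order = np.argsort(scores)[::-1][:k]\n",
        "  return [(labels[rows[i]], labels[cols[i]], float(scores[i])) for i in order]\n",
        "\n",
        "toy_sim = np.array([[1.0, 0.9, 0.1],\n",
        "                    [0.9, 1.0, 0.4],\n",
        "                    [0.1, 0.4, 1.0]])\n",
        "for a, b, score in top_similar_pairs(toy_sim, ['a', 'b', 'c'], k=2):\n",
        "  print(f'{score:.2f}  {a} <-> {b}')"
      ]
    },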
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tJ4QCyzhSL7B"
      },
      "source": [
        "## Learn more\n",
        "\n",
        "* Find more BERT models on [TensorFlow Hub](https://tfhub.dev)\n",
        "* This notebook demonstrates simple inference with BERT; you can find a more advanced tutorial about fine-tuning BERT at [tensorflow.org/official_models/fine_tuning_bert](https://www.tensorflow.org/official_models/fine_tuning_bert)\n",
        "* We used a single GPU to run the model; you can learn more about how to load models using `tf.distribute` at [tensorflow.org/tutorials/distribute/save_and_load](https://www.tensorflow.org/tutorials/distribute/save_and_load)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "bert_experts.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
